import copy

import torch
from torch import nn

from algos.networks import MLP
from tools.utils import soft_update


class Info:
    """Base predictor wrapper around an embedding network and its target copy.

    The instance keeps the online network (``mlp``), a deep copy of it
    (``target_mlp``), an Adam optimizer over the online parameters, and the
    counters/buffers used for logging and for accumulating a training
    objective. ``label_embed`` delegates to ``self.target_embed``, which is
    NOT defined here: subclasses must provide it.
    """

    def __init__(self, mlp, args, context, rollouts):
        """Store the collaborators and build the target network and optimizer.

        Args:
            mlp: Online network; must expose ``parameters()`` and be
                deep-copyable.
            args: Configuration namespace. Read here: ``lr_pred``,
                ``weight_decay`` and ``batch_size``.
            context: Shared run context, read by ``check_gradient``, ``load``
                and ``save``.
            rollouts: Storage object used by ``after_update``.
        """
        self.context = context
        self.rollouts = rollouts
        self.args = args

        self.mlp = mlp
        self.target_mlp = copy.deepcopy(self.mlp)

        self.optimizer = torch.optim.Adam(
            self.mlp.parameters(),
            lr=args.lr_pred,
            weight_decay=args.weight_decay,
        )
        self.batch_size = args.batch_size

        # Last logged values (placeholders until a subclass overwrites them).
        self.save_sum_new_loss = torch.zeros((1,))
        self.mean_ch_target = torch.zeros((1,))

        self.store = None
        self.num_updates = 0

        # Accumulators for an objective built over several calls.
        self.disc_updates = 0
        self.all_objective = 0
        self.cpt = 0

        # Ring of the most recent gradient readings, see ``check_gradient``.
        self.interval_gradient = 10
        self.log_gradients = torch.zeros((self.interval_gradient,))
        self.eval_mode = False

    def label_embed(self, input, act=False):
        """Return ``self.target_embed(input)``; ``act`` is accepted but unused."""
        return self.target_embed(input)

    def goal_embed(self, input, act=False):
        """Embed a goal observation; identical to ``label_embed(input)``."""
        return self.label_embed(input)

    def compute_feedback(self, out, target):
        """Return the negative L2 distance between ``out`` and ``target`` along dim 1."""
        return -torch.norm(out - target, 2, dim=1)

    def evaluate(self, inputs, goals):
        """Score ``inputs`` against already-embedded ``goals``.

        Both tensors are moved to ``args.device``; ``inputs`` is embedded with
        ``label_embed`` and the result is the negative L2 distance to ``goals``.
        """
        inputs = inputs.to(self.args.device)
        goals = goals.to(self.args.device)
        embedding = self.label_embed(inputs, act=True)
        return self.compute_feedback(goals, embedding)

    def step(self, **kwargs):
        """No-op hook; subclasses may override."""
        pass

    def after_update(self):
        """Hand the latest feedback to the rollout storage while training.

        Does nothing when the storage cannot learn yet or when the instance is
        in evaluation mode. ``self.feed`` is not assigned anywhere in this
        class; in this file it is set at the end of ``DistopInfo.learn``.
        """
        if not self.rollouts.can_learn() or self.eval_mode:
            return

        batch = self.rollouts.get_evals()
        masks = torch.ones_like(batch.distribution_id, dtype=torch.bool)
        if not self.args.ratio_for_predictor:
            # Exclude the samples flagged by ``batch.sac_train``.
            masks = masks & ~batch.sac_train
        if masks.any():
            self.rollouts.insert_goal_data(batch, masks, self.feed)

    def check_gradient(self):
        """Record the network gradient during the last steps before a log point.

        ``rest_log`` is the number of steps remaining until
        ``total_num_steps`` reaches the next multiple of ``log_interval``;
        a reading is stored at that index when it is below
        ``interval_gradient``.
        """
        policy = self.context.mpolicy
        rest_log = policy.log_interval - policy.total_num_steps % policy.log_interval
        if rest_log < self.interval_gradient:
            with torch.no_grad():
                grad_mlp = self.mlp.get_gradient()
                self.log_gradients[rest_log] = grad_mlp

    def print(self, *args, **kwargs):
        """No-op logging hook; subclasses may override."""
        pass

    def eval(self):
        """Switch to evaluation mode (disables ``after_update``)."""
        self.eval_mode = True

    def train(self):
        """Switch back to training mode."""
        self.eval_mode = False

    def load(self):
        """Restore network and optimizer state from ``<load_model_path>Predictor.pt``.

        Does nothing when ``context.load_model_path`` is empty. After loading,
        the learning rate of every parameter group is reset to
        ``args.lr_pred`` and, when ``args.clone_negative`` is truthy, the
        target network is re-cloned from the restored online network.
        """
        if not self.context.load_model_path:
            return

        path = self.context.load_model_path + "Predictor.pt"
        # The original code referenced an undefined ``CUDA`` global here,
        # which raised NameError; ask torch directly whether CUDA is usable.
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        checkpoint = torch.load(path, map_location=device)
        self.mlp.load_state_dict(checkpoint['network_state_dict'])
        self.optimizer.load_state_dict(checkpoint['network_optimizer_state_dict'])
        for param_group in self.optimizer.param_groups:
            # The configured learning rate wins over the checkpointed one.
            param_group["lr"] = self.args.lr_pred

        if self.args.clone_negative:
            self.target_mlp = copy.deepcopy(self.mlp)

    def save(self):
        """Write network and optimizer state to ``<path_models>Predictor.pt``.

        Does nothing unless ``context.save_model`` is truthy.
        """
        if not self.context.save_model:
            return

        path = self.context.path_models + "Predictor.pt"
        obj = {
            "network_state_dict": self.mlp.state_dict(),
            'network_optimizer_state_dict': self.optimizer.state_dict(),
        }
        torch.save(obj, path)


class DistopInfo(Info):
    """
    Standard predictor: it takes a batch from the rollout buffer, accumulates
    a training objective for the embedding network ``mlp`` on it, and writes
    target-network embeddings and intrinsic rewards back onto the batch for
    the policies.

    Depending on ``args.type`` it also owns auxiliary networks:
      * type 2: ``mlp_pred`` maps (previous embedding, action) to an embedding.
      * type 3: ``mlp_pred`` is trained to match the output of ``mlp_pred1``.
    """

    def __init__(self, mlp, args,context, rollouts):
        """Build the base predictor plus the auxiliary networks for type 2 / type 3.

        Parameters are forwarded unchanged to ``Info.__init__``.
        """
        super().__init__(mlp, args, context, rollouts)
        if self.args.type == 2:
            # Input is an embedding concatenated with an action, output is an embedding.
            self.mlp_pred = MLP(num_inputs=self.args.num_latents+context.action_size, num_output=self.args.num_latents,activation=nn.ELU,num_layers=2, hidden_size=128).to(args.device)
            self.optimizer_pred=torch.optim.Adam(self.mlp_pred.parameters(),lr=5e-4,weight_decay=1e-6)
        if self.args.type == 3:
            # Two networks with the same architecture; only ``mlp_pred`` gets an
            # optimizer, ``mlp_pred1`` is never updated in this class.
            self.mlp_pred1 = MLP(num_inputs=self.mlp.num_inputs, num_output=10,activation=nn.ELU,num_layers=2, hidden_size=128).to(args.device)
            self.mlp_pred = MLP(num_inputs=self.mlp.num_inputs, num_output=10,activation=nn.ELU,num_layers=2, hidden_size=128).to(args.device)
            self.optimizer_pred=torch.optim.Adam(self.mlp_pred.parameters(),lr=1e-4,weight_decay=1e-6)

    def embed(self,inputs,**kwargs):
        """Embed ``inputs`` with the online network (gradients enabled).

        Only the columns from ``context.start_discriminator`` onward are fed
        to the network; extra keyword arguments are accepted and ignored.
        """
        return self.mlp(inputs[:, self.context.start_discriminator:])

    def target_embed(self,inputs):
        """Embed ``inputs`` with the target network, without tracking gradients."""
        with torch.no_grad():
            e= self.target_mlp(inputs[:,self.context.start_discriminator:])
        return e

    def learn(self):
        """Run one learning pass on a batch drawn from the rollout storage.

        Steps, in order:
          1. Embed goals, observations and next observations (online and
             target networks) and attach the target embeddings to the batch.
          2. Add this batch's terms to ``self.all_objective``: a hinge penalty
             on the distance between consecutive embeddings, an MSE term
             against the target embeddings, and a contrastive term.
          3. Once enough samples have been accumulated, take an optimizer
             step and update the target network.
          4. Compute the feedback, optionally relabel goals, and add the
             intrinsic reward for the configured ``args.type`` to
             ``batch.irewards``.

        Returns early (doing nothing) when the storage cannot learn yet or
        when the instance is in evaluation mode.
        """
        if not (self.rollouts.can_learn()) or self.eval_mode:
            return
        batch = self.rollouts.get_evals()
        inputs=batch.next_obs
        prev_inputs=batch.obs

        self.num_updates +=1
        with torch.no_grad():
            batch.goals = self.goal_embed(batch.goals_obs)
        # NOTE(review): squeeze() drops every size-1 dim, including the batch
        # dim when the batch holds a single sample -- confirm this is intended.
        batch.label_goals =batch.goals.clone().detach().squeeze()
        
        inputs=inputs.clone()
        prev_inputs = prev_inputs.clone()

        # if not self.args.separate_learning or self.args.embed_goal== 0 or self.args.target_prev_only:
        # Online-network embeddings (these carry gradients).
        pos_embeddings=self.embed(inputs,store=True)
        prev_embeddings = self.embed(prev_inputs)
        # L2 distance between the embeddings of ``obs`` and ``next_obs``.
        local_distance = torch.norm(prev_embeddings - pos_embeddings, 2, dim=1)

        ### Target embeddings for the DRL algorithm and OEGN
        batch.prev_embeddings = self.label_embed(prev_inputs)
        batch.embeddings = self.label_embed(inputs)


        # The pi_* and som_* fields below are aliases of the same tensors, not copies.
        batch.pi_state_embeddings = batch.embeddings
        batch.pi_prev_state_embeddings = batch.prev_embeddings
        batch.pi_goal_embeddings = batch.label_goals

        batch.som_embeddings = batch.embeddings
        batch.som_prev_embeddings = batch.prev_embeddings
        target_embedding = batch.goals

        ### Relabeling mask
        masks=batch.distribution_id.clone()

        ### Learning mask for the predictor
        if self.args.data_for_predictor:
            masks2 =masks
        else:
            # Use every sample of the batch.
            masks2=torch.ones((pos_embeddings.shape[0],),dtype=torch.bool)


        ### To display on the OEGN network
        self.store = pos_embeddings.detach()
        self.prev_store = prev_embeddings.detach()
        
        
        weights=1
        local_distance_masked = local_distance[masks2]
        ### We construct the objective function
        # Skipped when no sample is selected, or when ``stop_after_warmup`` is
        # set and ``total_num_steps`` has reached ``warmup``.
        if local_distance_masked.shape[0] > 0 and (not self.args.stop_after_warmup or self.context.mpolicy.total_num_steps < self.args.warmup):
            # NOTE(review): same value as computed just above the ``if``.
            local_distance_masked = local_distance[masks2]

            ### Close constraint
            # Hinge penalty: only distances above ``reg_pred`` (or squared
            # distances above ``reg_pred**2``) contribute, scaled by ``reg_coef``.
            if self.args.square:
                newscore = self.args.reg_coef * torch.nn.functional.relu(local_distance_masked.view(-1, 1)**2 - self.args.reg_pred**2).sum(dim=1)
            else:
                newscore = self.args.reg_coef * torch.nn.functional.relu(local_distance_masked.view(-1, 1) - self.args.reg_pred).sum(dim=1)
            self.all_objective = self.all_objective+ (newscore*weights).sum()

            ### Consistency constraint
            # Summed squared error between online and target embeddings of ``next_obs``.
            negative_target=batch.embeddings[masks2]
            self.all_objective = self.all_objective + self.args.coef_clone_negative * torch.nn.functional.mse_loss(pos_embeddings[masks2],negative_target, reduction="sum")
            # Logged only: mean online/target embedding distance.
            self.mean_ch_target = torch.norm(pos_embeddings[masks2]-negative_target,2,dim=1).mean().detach()


            # Shape (N, 1, D) so that it broadcasts against (N, K, D) negatives.
            positive_example = pos_embeddings[masks2].unsqueeze(1)
            pose = positive_example
            table = []

            ### Build negative examples
            aranged = torch.arange(0, pose.shape[0])
            for i in range(self.args.number_negative):
                # Random partner index per sample; where a sample drew itself,
                # shift to the next index (wrapping around).
                indices2 = torch.randint(0, pose.shape[0], (pose.shape[0],))
                m = (indices2 ==aranged)
                indices2[m] = (indices2[m] + 1) % pose.shape[0]
                samples = pose[indices2]
                table.append(samples)
            # One column of negatives per draw: (N, number_negative, D).
            embeddings = torch.cat(table, dim=1)
            ### Contrastive part
            self.all_objective = self.all_objective - self.compute_negative_loss(positive_example,embeddings,weights)

        ### Update the network
        # ``cpt`` counts accumulated samples; a step is taken once it reaches
        # 90% of ``batch_size``, and the objective is normalised by ``cpt``.
        self.cpt += local_distance_masked.shape[0]
        if self.cpt >= 0.9*self.batch_size :
            if (not self.args.stop_after_warmup or self.context.mpolicy.total_num_steps < self.args.warmup):
                self.optimizer.zero_grad()
                (self.all_objective/self.cpt) .backward()
                self.check_gradient()
                self.optimizer.step()
                # presumably moves target_mlp toward mlp with rate
                # ``clone_negative`` -- verify against tools.utils.soft_update
                soft_update(self.target_mlp, self.mlp, self.args.clone_negative)
            self.disc_updates += 1
            # Reset the accumulators whether or not a step was taken.
            self.cpt=0
            self.all_objective=0

            # if self.args.embed_sac:
        # Distance between goal and state target embeddings, truncated to
        # ``batch_size`` entries; stored on ``self.feed`` at the end.
        feed = torch.norm(target_embedding - batch.pi_state_embeddings, 2, dim=1)[:self.batch_size]
        ### Relabeling

        with torch.no_grad():
            data_relabel = batch.pi_state_embeddings
            feedback = self.compute_feedback(data_relabel,target_embedding)

            if masks.any() and self.args.type == 0:
                ### Relabel samples
                if self.args.relabeling == 1 or self.args.relabeling == 2 or self.args.relabeling == 4:
                    # Relabel the masked samples, or all samples when ``relabeling2`` is set.
                    not_learned = masks if not self.args.relabeling2 else torch.ones_like(masks,dtype=torch.bool,device=masks.device)
                    need_relabel = data_relabel[not_learned]

                    # New goals are the target embeddings of ``batch.label_obs``.
                    new_labels = self.goal_embed(batch.label_obs[not_learned])
                    batch.pi_goal_embeddings[not_learned] = new_labels
                    feedback[not_learned] = self.compute_feedback(need_relabel,new_labels)
                elif self.args.relabeling == 3:  # take relabeling goals in B^S
                    # New goals are state embeddings drawn at random (with
                    # replacement) from the masked samples themselves.
                    not_learned = masks
                    need_relabel = data_relabel[not_learned]
                    indices = torch.randint(0, need_relabel.shape[0], (need_relabel.shape[0],), dtype=torch.long)
                    newlabels = need_relabel[indices]
                    batch.pi_goal_embeddings[not_learned] = newlabels
                    feedback[not_learned] = self.compute_feedback(newlabels, need_relabel)
                # NOTE(review): this line sits inside the ``masks.any()`` guard,
                # so no feedback reward is added when the mask is all-False --
                # confirm that this is intended.
                batch.irewards += self.args.reward_coef * feedback.detach().unsqueeze(1)
        if self.args.type == 1 or self.args.type == 4:
            # Reward from ``batch.densities`` plus ``batch.rewards``.
            batch.irewards += self.args.reward_coef * batch.densities.detach().unsqueeze(1) + batch.rewards.detach()
        if self.args.type == 2:
            # Predict the next target embedding from (previous embedding, action);
            # the per-sample squared error trains ``mlp_pred`` and is the reward.
            predictions_val = self.mlp_pred(torch.cat((batch.prev_embeddings.detach(),batch.actions.clone()),dim=1))
            # predictions_val = self.mlp_pred(torch.zeros_like(torch.cat((batch.prev_embeddings.detach(),batch.actions),dim=1)))
            # loss = torch.nn.functional.mse_loss(predictions_val,batch.embeddings.detach(),reduction="none").sum(dim=1)
            loss = torch.pow(predictions_val-batch.embeddings.detach(),2).sum(dim=1)
            self.optimizer_pred.zero_grad()
            loss.mean(dim=0).backward()
            self.optimizer_pred.step()
            batch.irewards += self.args.reward_coef * loss.detach().unsqueeze(1) + batch.rewards.detach()
        if self.args.type == 3:
            # ``mlp_pred`` is trained to match the (detached) output of
            # ``mlp_pred1``; the per-sample squared error is the reward.
            predictions_val = self.mlp_pred(inputs[:,self.context.start_discriminator:])
            predictions_val2 = self.mlp_pred1(inputs[:,self.context.start_discriminator:])
            loss = torch.nn.functional.mse_loss(predictions_val,predictions_val2.detach(),reduction="none").sum(dim=1)
            self.optimizer_pred.zero_grad()
            loss.mean(dim=0).backward()
            self.optimizer_pred.step()
            batch.irewards += self.args.reward_coef * loss.detach().unsqueeze(1) + batch.rewards.detach()
        self.feed = feed

    def compute_negative_loss(self,positive_example,embeddings,weights):
        """Return the weighted sum of log-probabilities of the contrastive term.

        For each sample, with ``d_k`` the L2 distance to its k-th negative and
        ``tau = args.tau_negative``::

            p = 1 / (1 + sum_k exp(-tau * d_k))

        The result is ``args.coef_negative * sum(weights * log(p))``. The
        unscaled sum is also stored, detached, in ``self.save_sum_new_loss``.

        Args:
            positive_example: Anchor embeddings; ``learn`` passes shape (N, 1, D).
            embeddings: Negative embeddings; ``learn`` passes shape (N, K, D).
            weights: Scalar or per-sample weight applied to ``log(p)``.
        """

        # The anchor is at distance 0 from itself, so its exponential below is 1.
        norm1 = torch.zeros(positive_example.shape[0], 1, device=positive_example.device)
        # Distance from each anchor to each of its negatives.
        norm_tmp = torch.norm(positive_example - embeddings[:, 0:], 2, dim=2)

        exp_norm = torch.exp(-self.args.tau_negative * norm1)
        exp_norm_tmp = torch.exp(-self.args.tau_negative * norm_tmp)
        all_exp_norm = torch.cat((exp_norm, exp_norm_tmp), dim=1)

        p = exp_norm.squeeze() / all_exp_norm.sum(dim=1)
        new_loss = torch.log(p)
        sum_new_loss = (weights*new_loss).sum()

        # Kept for logging.
        self.save_sum_new_loss = sum_new_loss.detach()
        return self.args.coef_negative * sum_new_loss
